import numpy as np
import torch
from torchvision.datasets import CIFAR10,CIFAR100
from torchvision import transforms
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from backbone import ResNet_size
from model import OntoEncoder
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import CustomKNN
from pytorch_metric_learning.distances import BaseDistance
from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from sklearn.manifold import TSNE
from data_utils import *

# Embedding/output dimensionality constant (not referenced in the visible part of this file).
output_dim = 2
# Number of classes iterated over by the plotting helper.
num_classes = 10

# Use the first GPU when PyTorch can see one, otherwise run on the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

class CEDistance(BaseDistance):
    """Per-coordinate log distance, summed over the feature dimension.

    For embeddings ``x`` and ``y`` and ``eps, alpha = parameter``::

        d(x, y) = sum_k [ log(1 + eps) - log(1 + eps - |x_k - y_k| ** alpha) ]

    For ``alpha > 0`` this is 0 when ``x == y`` and grows as coordinates move
    apart, so smaller means closer (the distance is not inverted). Each term is
    finite only while ``|x_k - y_k| ** alpha < 1 + eps``. NOTE(review): this
    presumably assumes embedding coordinates lie in a unit range; confirm
    against the model.
    """

    def __init__(self, parameter=None, **kwargs):
        """
        Args:
            parameter: ``(eps, alpha)`` pair; defaults to ``[0.000001, 1]``.
            **kwargs: forwarded to ``BaseDistance.__init__``.
        """
        super().__init__(**kwargs)
        # Build a fresh default per instance. A list literal as the default
        # would be one object shared by every instance.
        self.parameter = [0.000001, 1] if parameter is None else parameter
        assert not self.is_inverted

    def compute_mat(self, query_emb, ref_emb):
        """Return the ``(len(query_emb), len(ref_emb))`` matrix of distances."""
        # Compute in at least float32. In float16, 1 + eps rounds to exactly 1
        # for the small eps used here, so eps no longer keeps log() away from 0.
        dtype = torch.promote_types(query_emb.dtype, torch.float32)
        query_emb = query_emb.to(dtype)
        ref_emb = ref_emb.to(dtype)
        output = torch.empty(query_emb.shape[0], ref_emb.shape[0],
                             dtype=dtype, device=query_emb.device)
        # Handle one query row at a time, vectorised over all references. This
        # needs O(len(ref) * dim) memory instead of a full 3-D broadcast.
        for i, query_row in enumerate(query_emb):
            output[i] = self.pairwise_distance(query_row, ref_emb)
        return output

    def pairwise_distance(self, query_emb, ref_emb):
        """Distance between (broadcast-aligned) embeddings, reduced over the last dim.

        ``(dim,)`` vs ``(dim,)`` yields a scalar, ``(n, dim)`` vs ``(n, dim)``
        yields ``n`` distances, and ``(dim,)`` vs ``(m, dim)`` yields ``m``.
        """
        eps, alpha = self.parameter
        log_term = torch.log(1 + eps - torch.abs(query_emb - ref_emb) ** alpha)
        return (float(np.log(1 + eps)) - log_term).sum(dim=-1)

# ``Metrics`` is defined once, further down in this module. A duplicate
# definition that the later one shadowed used to sit here and was removed.


# Matplotlib colour for each class index (indexed by label in the scatter plot).
color = [
    'tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
    'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan',
]
# Per-class marker symbols (not referenced in the visible part of this file).
marker = [
    '.', '+', 'x', '1', '^',
    's', 'p', '*', 'd', 'X',
]
# Display names for labels 0-9 (presumably CIFAR-10's label order).
classes = [
    'Plane', 'Car', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]


def Ploting2D(embeddings_plot, labels_plot, tit="default", x_axis="X", y_axis="Y",
              save_path='./sample/redution.jpg'):
    """Scatter-plot 2-D embeddings coloured by class and save the figure.

    Args:
        embeddings_plot: array indexed as ``[axis, sample]``. Row 0 is plotted
            on x and row 1 on y.
        labels_plot: per-sample labels. ``labels_plot == i`` must yield a boolean
            mask (e.g. a numpy array).
        tit: figure title.
        x_axis: x-axis label.
        y_axis: y-axis label.
        save_path: output image path. Its parent directory is created if missing.
    """
    import os

    fig, ax = plt.subplots()
    for i in range(num_classes):
        index = labels_plot == i
        ax.scatter(embeddings_plot[0, index], embeddings_plot[1, index],
                   s=3, marker='.', c=color[i], label=classes[i])
    ax.legend(loc='best', title="Labels", markerscale=5.0)
    ax.grid(True, linestyle='--')
    ax.set_title(tit)
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    # Lay out only after every text element exists; otherwise the axis labels
    # are not accounted for and can be clipped in the saved image.
    fig.tight_layout()

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path)
    # Release the figure. pyplot otherwise keeps every figure alive for the session.
    plt.close(fig)


def Metrics(y_real, y_pred, class_names=None):
    """Print overall and per-class classification scores; return macro recall.

    Args:
        y_real: ground-truth labels.
        y_pred: predicted labels, same length as ``y_real``.
        class_names: optional sequence mapping integer label -> display name.
            When omitted, the module-level ``classes`` list is used. Names are
            used only if they cover every label that occurs. Otherwise rows are
            labelled by the raw label value.

    Returns:
        Macro-averaged recall (unweighted mean of per-label recall).
    """
    acc = metrics.accuracy_score(y_real, y_pred)
    prec = metrics.precision_score(y_real, y_pred, average='macro')
    rec = metrics.recall_score(y_real, y_pred, average='macro')
    f = metrics.f1_score(y_real, y_pred, average='macro')

    print("The average scores for all classes:")
    # Precision/recall/F1 are macro averages: computed per label, then the
    # unweighted mean, so label imbalance is not taken into account.
    # Accuracy is plain overall accuracy.
    print("\nAccuracy:  {:.2f}%".format(acc * 100))  # correct / total
    print("Precision: {:.2f}%".format(prec * 100))  # mean over labels of TP/(TP+FP)
    print("Recall:    {:.2f}%".format(rec * 100))  # mean over labels of TP/(TP+FN)
    print("F-measure: {:.2f}%".format(f * 100))  # mean over labels of 2PR/(P+R)

    print("\nThe scores for each class:")
    # Pin the label order explicitly (sorted union of true and predicted
    # labels) so row k of the per-label arrays is known to be labels[k].
    labels = np.union1d(np.asarray(y_real), np.asarray(y_pred))
    precision, recall, fscore, support = metrics.precision_recall_fscore_support(
        y_real, y_pred, labels=labels)

    names = classes if class_names is None else class_names
    # Name rows only when the names cover every label. Otherwise, e.g., 100
    # labels would be shown with 10 default names, or indexing would overflow.
    use_names = all(
        isinstance(label, (int, np.integer)) and 0 <= label < len(names)
        for label in labels
    )

    print("\n|    Label    |  Precision |  Recall  | F1-Score | Support")
    print("|-------------|------------|----------|----------|---------")
    for label, p, r, fs, s in zip(labels, precision, recall, fscore, support):
        name = names[label] if use_names else str(label)
        print(
            f"| {name:<11} |  {p * 100:<7.2f}%  | {r * 100:<7.2f}% |   {fs:<4.2f}   | {s}")

    return rec

def to_device(data, device):
    """Move a tensor/module, or a (nested) list or tuple of them, onto ``device``.

    Sequences come back as lists (tuples included). Any other object is moved
    with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(element, device) for element in data]

def _load_checkpoint(n_class, path):
    """Build a cross-entropy ``OntoEncoder`` on ``device`` and load weights from ``path``."""
    model = to_device(OntoEncoder(n_class=n_class, loss='crossentropy'), device)
    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    return model


def _build_dataset_and_model(dataset):
    """Return ``(train_dataset, test_dataset, model)`` for a dataset name.

    ``model`` is ``None`` when no checkpoint is configured (currently 'CUB').

    Raises:
        ValueError: if ``dataset`` is not 'CIFAR10', 'CIFAR100' or 'CUB'.
    """
    if dataset == 'CIFAR10':
        mean, std = (0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        # Must be the bool False: the string 'False' is truthy and enables downloading.
        test_dataset = CIFAR10(root='./datasets/', train=False, transform=preprocess, download=False)
        train_dataset = CIFAR10(root='./datasets/', train=True, transform=preprocess, download=False)
        model = _load_checkpoint(10, 'checkpoint/crossentropy_resnet18_CIFAR10_0.99.pth')
    elif dataset == 'CIFAR100':
        mean, std = ((0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                     (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        test_dataset = CIFAR100(root='./datasets/', train=False, transform=preprocess, download=False)
        train_dataset = CIFAR100(root='./datasets/', train=True, transform=preprocess, download=False)
        model = _load_checkpoint(100, 'checkpoint/crossentropy_resnet18_CIFAR100.pth')
    elif dataset == 'CUB':
        # Evaluation transform: no random crop/flip, so embeddings are reproducible.
        # Normalisation uses ImageNet statistics.
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        train_dataset = CUB(path='./datasets/CUB_200_2011', train=True, transform=preprocess)
        test_dataset = CUB(path='./datasets/CUB_200_2011', train=False, transform=preprocess)
        model = None  # TODO: no CUB checkpoint is configured yet.
    else:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected 'CIFAR10', 'CIFAR100' or 'CUB'")
    return train_dataset, test_dataset, model


def _ce_distance_np(x, y, parameter):
    """NumPy form of the CE distance for two 1-D vectors (usable as an sklearn metric)."""
    eps, alpha = parameter
    return np.sum(np.log(1 + eps) - np.log(1 + eps - np.abs(x - y) ** alpha))


def test(dataset=None, parameter=None):
    """Evaluate a trained encoder with 1-NN classification and retrieval metrics.

    Train-set embeddings are the reference set and test-set embeddings are the
    queries. The script prints sklearn classification scores from a 1-NN
    classifier, then NMI / precision@1 / R-precision / MAP@R from
    pytorch-metric-learning, both using the CE distance.

    Args:
        dataset: 'CIFAR10', 'CIFAR100' or 'CUB'.
        parameter: ``(eps, alpha)`` for the CE distance; defaults to ``[1e-7, 0.1]``.

    Raises:
        ValueError: for an unknown ``dataset``.
        NotImplementedError: when no model checkpoint is configured for ``dataset``.
    """
    if parameter is None:
        parameter = [0.0000001, 0.1]

    train_dataset, test_dataset, model = _build_dataset_and_model(dataset)
    if model is None:
        raise NotImplementedError(f"no model checkpoint is configured for dataset {dataset!r}")

    knn = KNeighborsClassifier(n_neighbors=1,
                               metric=lambda x, y: _ce_distance_np(x, y, parameter))
    # Ablation: for plain Euclidean distance, use metric='euclidean' instead.

    reference_embeddings, reference_labels = model.extract_embedding(train_dataset)
    query_embeddings, query_labels = model.extract_embedding(test_dataset)
    knn.fit(reference_embeddings, reference_labels)
    y_pred = knn.predict(query_embeddings)

    recall = Metrics(query_labels, y_pred)
    print(recall)

    knn_func = CustomKNN(CEDistance(parameter))
    calculator = AccuracyCalculator(exclude=("AMI", "mean_average_precision"),
                                    avg_of_avgs=False,
                                    k="max_bin_count",
                                    label_comparison_fn=None,
                                    knn_func=knn_func)
    acc_dict = calculator.get_accuracy(query_embeddings, query_labels,
                                       reference_embeddings, reference_labels)

    print("\nNMI: ", acc_dict["NMI"] * 100)
    print("p@1: ", acc_dict["precision_at_1"] * 100)
    print("RP: ", acc_dict["r_precision"] * 100)
    print("MAP@R: ", acc_dict["mean_average_precision_at_r"] * 100)

    torch.cuda.empty_cache()


if __name__ == "__main__":
    # Script entry point: evaluate the CIFAR-100 checkpoint.
    test(dataset="CIFAR100")





